/**
 * Deep4jNetworkWrapper.java created by zhangzhidong 
 * at 上午10:07:50 2017年5月26日
 */
package cn.edu.bjtu.model.core;

import java.io.Serializable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import cn.edu.bjtu.api.NetworkModelService;
import cn.edu.bjtu.api.NeuronNetwork;
import cn.edu.bjtu.core.DefinedThreadFactory;

/**
 * Wrapper that runs the wrapped network's {@code output} and {@code evaluate}
 * calls on a fixed-size thread pool, offering both blocking and
 * {@link Future}-returning variants of each call.
 * <p>
 * Original author note: wrapper class that supports retrieving the update time
 * (presumably provided by {@code AbstractNetWorkWrapper}; not visible in this
 * file - TODO confirm).
 * <p>
 * The class is declared {@link Serializable}; the thread pool and its thread
 * factory are runtime-only resources, so they are {@code transient} and are
 * rebuilt on first use if the instance was restored without running the
 * constructor.
 *
 * @author Alex
 */
public class Deep4jNetworkWrapper extends AbstractNetWorkWrapper implements NeuronNetwork,Serializable{
	private static final long serialVersionUID = 1L;
	//dubbo remote service

	/** Size of the worker pool; resolved in the constructor, -1 until then. */
	private int intCpuNumber = -1;
	/** Factory for the pool's worker threads; transient because it is only needed to build the pool. */
	private transient ThreadFactory _tf = new DefinedThreadFactory();
	/** Pool on which every call to {@code slave} is executed; transient because thread pools cannot be serialized. */
	private transient ExecutorService es;

	/**
	 * Creates the wrapper and its worker pool.
	 *
	 * @param nms    the service used to read the neural network model
	 * @param cpuNum how many threads may run computations concurrently; any
	 *               value {@code <= 0} (e.g. -1) or any value larger than
	 *               {@code Runtime.availableProcessors()} is replaced by
	 *               {@code availableProcessors()}
	 */
	public Deep4jNetworkWrapper(NetworkModelService nms,int cpuNum) {
		this.nms = nms;
		int available = Runtime.getRuntime().availableProcessors();
		// A pool size of 0 is rejected by Executors.newFixedThreadPool, so every
		// non-positive request (not only negative ones) falls back to the CPU count.
		intCpuNumber = (cpuNum <= 0 || cpuNum > available) ? available : cpuNum;
		es = Executors.newFixedThreadPool(intCpuNumber,_tf);
	}

	/**
	 * Returns the worker pool, re-creating it (and the thread factory) when the
	 * transient fields are {@code null}, which happens when the instance was
	 * deserialized instead of constructed.
	 */
	private synchronized ExecutorService executor() {
		if (es == null) {
			if (_tf == null) {
				_tf = new DefinedThreadFactory();
			}
			int threads = intCpuNumber > 0 ? intCpuNumber : Runtime.getRuntime().availableProcessors();
			es = Executors.newFixedThreadPool(threads,_tf);
		}
		return es;
	}

	/**
	 * Runs {@code slave.output(input)} on the worker pool and blocks until the
	 * result is available.
	 *
	 * @param input the input arrays handed to the wrapped network
	 * @return the arrays returned by {@code slave.output}
	 * @throws Exception if the task fails (reported by {@link Future#get()} as an
	 *                   {@code ExecutionException}) or the waiting thread is interrupted
	 */
	@Override
	public INDArray[] output(INDArray... input) throws Exception{
		Future<INDArray[]> pending = executor().submit(()->slave.output(input));
		return pending.get();
	}

	/**
	 * Submits {@code slave.output(input)} to the worker pool without waiting.
	 *
	 * @param input the input arrays handed to the wrapped network
	 * @return a future that completes with the arrays returned by {@code slave.output}
	 * @throws Exception if the task cannot be submitted
	 */
	@Override
	public Future<INDArray[]> outputAsync(INDArray... input) throws Exception {
		return executor().submit(()->slave.output(input));
	}

	/**
	 * Runs {@code slave.evaluate(iterator)} on the worker pool and blocks until
	 * the evaluation is available.
	 *
	 * @param iterator the data set iterator handed to the wrapped network
	 * @return the evaluation returned by {@code slave.evaluate}
	 * @throws Exception if the task fails (reported by {@link Future#get()} as an
	 *                   {@code ExecutionException}) or the waiting thread is interrupted
	 */
	@Override
	public Evaluation evaluate(DataSetIterator iterator) throws Exception {
		Future<Evaluation> pending = executor().submit(()-> slave.evaluate(iterator));
		return pending.get();
	}

	/**
	 * Submits {@code slave.evaluate(iterator)} to the worker pool without waiting.
	 *
	 * @param iterator the data set iterator handed to the wrapped network
	 * @return a future that completes with the evaluation returned by {@code slave.evaluate}
	 * @throws Exception if the task cannot be submitted
	 */
	@Override
	public Future<Evaluation> evaluateAsync(DataSetIterator iterator) throws Exception {
		return executor().submit(()-> slave.evaluate(iterator));
	}

}
